Skip to content

Commit 72d00a8

Browse files
wang2yn84 and The tunix Authors
authored and committed
Copybara import of the project:
-- a407ff2 by Lance Wang <lancewang@google.com>: Add vLLM to dependency list since it's OSS-ed. COPYBARA_INTEGRATE_REVIEW=#582 from google:lance-add-vllm a407ff2 PiperOrigin-RevId: 821078416
1 parent c55e233 commit 72d00a8

File tree

3 files changed

+83
-19
lines changed

3 files changed

+83
-19
lines changed

.github/workflows/tpu-tests.yml

Lines changed: 68 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,12 @@ concurrency:
2929
group: ${{ github.event_name == 'pull_request' && format('{0}-pr-{1}', github.workflow, github.event.pull_request.number) || github.event_name == 'schedule' && format('{0}-schedule', github.workflow) || github.run_id }}
3030
cancel-in-progress: true
3131

32+
env:
33+
HF_HOME: ~/.cache/huggingface
34+
HF_HUB_ENABLE_HF_TRANSFER: "1"
35+
3236
jobs:
33-
run:
37+
run_prod:
3438
runs-on: [linux-x86-ct5lp-224-8tpu]
3539
environment: testing
3640
container:
@@ -40,15 +44,25 @@ jobs:
4044
CLOUD_TPU_ACCELERATOR: v5e-8
4145
JAX_PLATFORMS: tpu
4246
steps:
47+
48+
# Cache Hugging Face hub
49+
- name: Cache HF hub
50+
uses: actions/cache@v4
51+
with:
52+
path: ~/.cache/huggingface
53+
key: hf-${{ runner.os }}-${{ hashFiles('pyproject.toml', 'requirements*.txt', 'constraints*.txt') }}
54+
restore-keys: |
55+
hf-${{ runner.os }}-
56+
4357
- name: Checkout code
4458
uses: actions/checkout@v4
4559
with:
4660
fetch-depth: 0
4761

4862
- name: Install tunix dependencies
4963
run: |
50-
pip install -e .
51-
pip install pytest pytest-xdist jinja2
64+
pip install -e .[prod]
65+
pip install pytest pytest-xdist
5266
5367
- name: Verify TPU availability
5468
run: |
@@ -85,7 +99,6 @@ jobs:
8599
86100
- name: Run tunix generation tests (PASSED only)
87101
run: |
88-
# vllm_sampler_test depends on vllm TPU which is not OSS yet
89102
# tokenizer_adapter_test requires access to gated repo
90103
python -m pytest tests/generate/ -v --tb=short \
91104
--ignore=tests/generate/vllm_sampler_test.py \
@@ -94,7 +107,7 @@ jobs:
94107
- name: Run tunix SFT tests
95108
run: |
96109
python -m pytest tests/sft/ -v --tb=short
97-
110+
98111
- name: Run tunix SFT integration tests
99112
env:
100113
HF_TOKEN: ${{ secrets.HF_TOKEN }}
@@ -115,28 +128,28 @@ jobs:
115128
env:
116129
HF_TOKEN: ${{ secrets.HF_TOKEN }}
117130
run: |
118-
131+
119132
# Download GSM8K dataset
120133
mkdir -p /tmp/grpo_test/rl/grpo/data
121134
python3 -c "
122135
from datasets import load_dataset
123136
import json
124-
137+
125138
# Download and save GSM8K train split
126139
dataset = load_dataset('openai/gsm8k', 'main', split='train')
127140
train_data = [{'question': item['question'], 'answer': item['answer']} for item in dataset]
128141
with open('/tmp/grpo_test/rl/grpo/data/gsm8k_train.json', 'w') as f:
129142
json.dump(train_data, f)
130-
143+
131144
# Download and save GSM8K test split
132145
dataset = load_dataset('openai/gsm8k', 'main', split='test')
133146
test_data = [{'question': item['question'], 'answer': item['answer']} for item in dataset]
134147
with open('/tmp/grpo_test/rl/grpo/data/gsm8k_test.json', 'w') as f:
135148
json.dump(test_data, f)
136-
149+
137150
print('GSM8K dataset downloaded successfully')
138151
"
139-
152+
140153
# Run GRPO demo script with minimal configuration
141154
python3 scripts/grpo_demo_llama3_qwen2.py \
142155
--root-dir=/tmp/grpo_test \
@@ -156,3 +169,48 @@ jobs:
156169
exit "${code:-0}"
157170
fi
158171
172+
run_dev:
173+
runs-on: [linux-x86-ct5lp-224-8tpu]
174+
environment: testing
175+
container:
176+
image: vllm/vllm-tpu:v0.11.1
177+
options: --privileged
178+
env:
179+
CLOUD_TPU_ACCELERATOR: v5e-8
180+
JAX_PLATFORMS: tpu
181+
steps:
182+
# Cache Hugging Face hub
183+
- name: Cache HF hub
184+
uses: actions/cache@v4
185+
with:
186+
path: ~/.cache/huggingface
187+
key: hf-${{ runner.os }}-${{ hashFiles('pyproject.toml', 'requirements*.txt', 'constraints*.txt') }}
188+
restore-keys: |
189+
hf-${{ runner.os }}-
190+
191+
- name: Checkout code
192+
uses: actions/checkout@v4
193+
with:
194+
fetch-depth: 0
195+
196+
- name: Setup Tunix and tpu-inference
197+
run: |
198+
echo "Current directory:"
199+
pwd
200+
pip install --upgrade pip setuptools wheel
201+
202+
# Install Tunix
203+
pip uninstall torch torch-xla libtpu jax jaxlib -y
204+
pip install -e .[dev]
205+
206+
# Install tpu-inference
207+
pip uninstall torch libtpu jax jaxlib -y
208+
pip install tpu-inference==v0.11.1 --force-reinstall
209+
pip install pytest pytest-xdist
210+
211+
- name: Run tests
212+
env:
213+
HF_TOKEN: ${{ secrets.HF_TOKEN }}
214+
run: |
215+
pytest tests/generate/vllm_sampler_test.py -v --tb=short
216+

README.md

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,13 +83,19 @@ pip install git+https://github.com/google/tunix
8383
```
8484

8585
3. From source (editable install) If you plan to modify the codebase and run it
86-
in development mode:
86+
in development mode. If you'd like to install vLLM, note that the
87+
tpu-inference-supported version is not released yet; please follow the
88+
instructions to install it manually
89+
(https://docs.vllm.ai/en/latest/getting_started/installation/google_tpu.html)
90+
or download the Docker image (vllm/vllm-tpu:v0.11.1) and then run
91+
`pip install tpu-inference` for the TPU backend:
8792

8893
```sh
8994
git clone https://github.com/google/tunix.git
9095
cd tunix
9196
pip install -e ".[dev]"
9297

98+
# Then install vLLM and tpu-inference
9399
```
94100

95101
## Getting Started

pyproject.toml

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -19,21 +19,23 @@ classifiers = [
1919
]
2020
dependencies = [
2121
"datasets",
22+
"flax>=0.11.1",
2223
"gcsfs",
2324
"grain",
2425
"huggingface_hub",
25-
"jax[tpu]>=0.6.0,!=0.7.2", # Jax 0.7.2 has performance regression on OSS
2626
"jaxtyping",
27+
"jinja2", # Huggingface chat template
2728
"kagglehub",
28-
"omegaconf",
29+
"numba",
30+
"omegaconf", # CLI config
31+
"python-dotenv", # Huggingface API key
2932
"qwix",
3033
"sentencepiece",
3134
"tensorboardX",
3235
"tensorflow_datasets",
3336
"tqdm",
3437
"transformers",
35-
"python-dotenv",
36-
"jinja2",
38+
"hf_transfer", # Huggingface caching in CI
3739
]
3840

3941
[project.optional-dependencies]
@@ -49,12 +51,10 @@ docs = [
4951
"sphinx_contributors",
5052
]
5153
prod = [
52-
"flax>=0.11.2",
54+
"jax[tpu]>=0.6.0,!=0.7.2", # Jax 0.7.2 has performance regression on OSS
5355
]
5456
dev = [
55-
"flax>=0.11.2",
56-
"numba",
57-
"vllm",
57+
# Manually install vLLM & tpu-inference, which depends on jax[tpu]==0.7.2
5858
]
5959

6060
[project.urls]

0 commit comments

Comments
 (0)