@@ -29,8 +29,12 @@ concurrency:
2929 group : ${{ github.event_name == 'pull_request' && format('{0}-pr-{1}', github.workflow, github.event.pull_request.number) || github.event_name == 'schedule' && format('{0}-schedule', github.workflow) || github.run_id }}
3030 cancel-in-progress : true
3131
32+ env :
33+ HF_HOME : ~/.cache/huggingface
34+ HF_HUB_ENABLE_HF_TRANSFER : "1"
35+
3236jobs :
33- run :
37+ run_prod :
3438 runs-on : [linux-x86-ct5lp-224-8tpu]
3539 environment : testing
3640 container :
@@ -40,15 +44,25 @@ jobs:
4044 CLOUD_TPU_ACCELERATOR : v5e-8
4145 JAX_PLATFORMS : tpu
4246 steps :
47+
48+ # Cache Hugging Face hub
49+ - name : Cache HF hub
50+ uses : actions/cache@v4
51+ with :
52+ path : ~/.cache/huggingface
53+ key : hf-${{ runner.os }}-${{ hashFiles('pyproject.toml', 'requirements*.txt', 'constraints*.txt') }}
54+ restore-keys : |
55+ hf-${{ runner.os }}-
56+
4357 - name : Checkout code
4458 uses : actions/checkout@v4
4559 with :
4660 fetch-depth : 0
4761
4862 - name : Install tunix dependencies
4963 run : |
50- pip install -e .
51- pip install pytest pytest-xdist jinja2
64+ pip install -e .[prod]
65+ pip install pytest pytest-xdist
5266
5367 - name : Verify TPU availability
5468 run : |
8599
86100 - name : Run tunix generation tests (PASSED only)
87101 run : |
88- # vllm_sampler_test depends on vllm TPU which is not OSS yet
89102 # tokenizer_adapter_test requires access to gated repo
90103 python -m pytest tests/generate/ -v --tb=short \
91104 --ignore=tests/generate/vllm_sampler_test.py \
94107 - name : Run tunix SFT tests
95108 run : |
96109 python -m pytest tests/sft/ -v --tb=short
97-
110+
98111 - name : Run tunix SFT integration tests
99112 env :
100113 HF_TOKEN : ${{ secrets.HF_TOKEN }}
@@ -115,28 +128,28 @@ jobs:
115128 env :
116129 HF_TOKEN : ${{ secrets.HF_TOKEN }}
117130 run : |
118-
131+
119132 # Download GSM8K dataset
120133 mkdir -p /tmp/grpo_test/rl/grpo/data
121134 python3 -c "
122135 from datasets import load_dataset
123136 import json
124-
137+
125138 # Download and save GSM8K train split
126139 dataset = load_dataset('openai/gsm8k', 'main', split='train')
127140 train_data = [{'question': item['question'], 'answer': item['answer']} for item in dataset]
128141 with open('/tmp/grpo_test/rl/grpo/data/gsm8k_train.json', 'w') as f:
129142 json.dump(train_data, f)
130-
143+
131144 # Download and save GSM8K test split
132145 dataset = load_dataset('openai/gsm8k', 'main', split='test')
133146 test_data = [{'question': item['question'], 'answer': item['answer']} for item in dataset]
134147 with open('/tmp/grpo_test/rl/grpo/data/gsm8k_test.json', 'w') as f:
135148 json.dump(test_data, f)
136-
149+
137150 print('GSM8K dataset downloaded successfully')
138151 "
139-
152+
140153 # Run GRPO demo script with minimal configuration
141154 python3 scripts/grpo_demo_llama3_qwen2.py \
142155 --root-dir=/tmp/grpo_test \
@@ -156,3 +169,48 @@ jobs:
156169 exit "${code:-0}"
157170 fi
158171
172+ run_dev :
173+ runs-on : [linux-x86-ct5lp-224-8tpu]
174+ environment : testing
175+ container :
176+ image : vllm/vllm-tpu:v0.11.1
177+ options : --privileged
178+ env :
179+ CLOUD_TPU_ACCELERATOR : v5e-8
180+ JAX_PLATFORMS : tpu
181+ steps :
182+ # Cache Hugging Face hub
183+ - name : Cache HF hub
184+ uses : actions/cache@v4
185+ with :
186+ path : ~/.cache/huggingface
187+ key : hf-${{ runner.os }}-${{ hashFiles('pyproject.toml', 'requirements*.txt', 'constraints*.txt') }}
188+ restore-keys : |
189+ hf-${{ runner.os }}-
190+
191+ - name : Checkout code
192+ uses : actions/checkout@v4
193+ with :
194+ fetch-depth : 0
195+
196+ - name : Setup Tunix and tpu-inference
197+ run : |
198+ echo "Current directory:"
199+ pwd
200+ pip install --upgrade pip setuptools wheel
201+
202+ # Install Tunix
203+ pip uninstall torch torch-xla libtpu jax jaxlib -y
204+ pip install -e .[dev]
205+
206+ # Install tpu-inference
207+ pip uninstall torch libtpu jax jaxlib -y
208+ pip install tpu-inference==v0.11.1 --force-reinstall
209+ pip install pytest pytest-xdist
210+
211+ - name : Run tests
212+ env :
213+ HF_TOKEN : ${{ secrets.HF_TOKEN }}
214+ run : |
215+ pytest tests/generate/vllm_sampler_test.py -v --tb=short
216+
0 commit comments