Commit dc4da4a

update trl to 0.17.0 (axolotl-ai-cloud#2560)
* update trl to 0.17.0
* grpo + vllm no longer supported with torch 2.5.1 due to vllm constraints
* disable VLLM_USE_V1 for CI
* improve handling of killing off the multiprocessing vllm service
* debug why this doesn't run in CI
* increase vllm wait time
* increase timeout to 5 min
* upgrade to vllm 0.8.4
* dump out the vllm log for debugging
* use debug logging
* increase vllm start timeout
* use NVL instead
* disable torch compile cache
* revert some commented checks now that grpo tests are fixed
* increase vllm timeout back to 5 min
1 parent f9c7c3b commit dc4da4a
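Several of the bullets above (switching NCCL_P2P_LEVEL to NVL, disabling the torch compile cache, commenting out VLLM_USE_V1) land concretely in the test environment further down this diff. A minimal sketch of that environment, using only the variable names and values the diff itself introduces:

import os

# Environment handed to the background `trl vllm-serve` process in the updated tests.
# NCCL_P2P_LEVEL and VLLM_DISABLE_COMPILE_CACHE come straight from the diff;
# CUDA_VISIBLE_DEVICES pins the server to the second GPU, as before the change.
current_env = os.environ.copy()
vllm_env = {
    "NCCL_P2P_LEVEL": "NVL",            # was "LOC" prior to this commit
    **current_env,
    "CUDA_VISIBLE_DEVICES": "1",
    "VLLM_DISABLE_COMPILE_CACHE": "1",  # replaces the old VLLM_USE_V1="0" toggle
}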

File tree

7 files changed: +93 −30 lines

.github/workflows/main.yml
.github/workflows/multi-gpu-e2e.yml
.github/workflows/tests.yml
cicd/multigpu.sh
requirements.txt
setup.py
tests/e2e/multigpu/solo/test_grpo.py

.github/workflows/main.yml

Lines changed: 1 addition & 1 deletion

@@ -24,7 +24,7 @@ jobs:
      cuda_version: 12.4.1
      python_version: "3.11"
      pytorch: 2.5.1
-     axolotl_extras: vllm
+     axolotl_extras:
    - cuda: 124
      cuda_version: 12.4.1
      python_version: "3.11"

.github/workflows/multi-gpu-e2e.yml

Lines changed: 1 addition & 1 deletion

@@ -43,7 +43,7 @@ jobs:
      cuda_version: 12.4.1
      python_version: "3.11"
      pytorch: 2.5.1
-     axolotl_extras: vllm
+     axolotl_extras:
      num_gpus: 2
      nightly_build: "true"
    - cuda: 126

.github/workflows/tests.yml

Lines changed: 1 addition & 1 deletion

@@ -269,7 +269,7 @@ jobs:
      python_version: "3.11"
      pytorch: 2.5.1
      num_gpus: 1
-     axolotl_extras: vllm
+     axolotl_extras:
    - cuda: 126
      cuda_version: 12.6.3
      python_version: "3.11"

cicd/multigpu.sh

Lines changed: 1 addition & 1 deletion

@@ -20,4 +20,4 @@ pytest -v --durations=10 -n1 /workspace/axolotl/tests/e2e/multigpu/patched/ \
   --cov-report=xml:multigpu-coverage.xml

 # Upload coverage to Codecov
-codecov upload-process -t $CODECOV_TOKEN -f multigpu-coverage.xml -F multigpu,docker-tests,pytorch-${PYTORCH_VERSION}
+codecov upload-process -t "${CODECOV_TOKEN}" -f multigpu-coverage.xml -F multigpu,docker-tests,pytorch-${PYTORCH_VERSION} || true

requirements.txt

Lines changed: 2 additions & 2 deletions

@@ -11,13 +11,13 @@ liger-kernel==0.5.8

 packaging==23.2

-peft==0.15.1
+peft==0.15.2
 transformers==4.51.3
 tokenizers>=0.21.1
 accelerate==1.6.0
 datasets==3.5.0
 deepspeed>=0.15.4
-trl==0.16.1
+trl==0.17.0
 hf_xet==1.0.0
 hqq==0.2.5
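A quick, throwaway check that the bumped pins are what actually ended up in the environment (not part of the commit; the package names below are just the ones pinned above):

from importlib.metadata import version

# Versions pinned in requirements.txt as of this commit.
expected = {"peft": "0.15.2", "trl": "0.17.0", "transformers": "4.51.3"}

for pkg, want in expected.items():
    got = version(pkg)
    assert got == want, f"{pkg}: expected {want}, found {got}"
print("pins match")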

setup.py

Lines changed: 2 additions & 2 deletions

@@ -67,13 +67,13 @@ def parse_requirements(extras_require_map):
     if (major, minor) >= (2, 7):
         _install_requires.pop(_install_requires.index(xformers_version))
         # _install_requires.append("xformers==0.0.29.post3")  # xformers seems to be hard pinned to 2.6.0
-        extras_require_map["vllm"] = ["vllm==0.8.3"]
+        extras_require_map["vllm"] = ["vllm==0.8.4"]
     elif (major, minor) >= (2, 6):
         _install_requires.pop(_install_requires.index(xformers_version))
         _install_requires.append(
             "xformers==0.0.29.post2"
         )  # vllm needs post2 w torch 2.6
-        extras_require_map["vllm"] = ["vllm==0.8.3"]
+        extras_require_map["vllm"] = ["vllm==0.8.4"]
     elif (major, minor) >= (2, 5):
         _install_requires.pop(_install_requires.index(xformers_version))
         if patch == 0:
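The setup.py hunk keeps the existing pattern of gating the vllm extra on the installed torch version and only bumps the pin from 0.8.3 to 0.8.4. A stripped-down, hypothetical sketch of that gating (the real function derives major/minor from the installed torch; the 2.5 branch here reflects the CI matrix dropping the vllm extra, which the hunk itself does not show):

from packaging.version import Version


def vllm_extra_for(torch_version: str) -> list[str]:
    """Hypothetical standalone version of the gating shown in the diff."""
    v = Version(torch_version)
    if (v.major, v.minor) >= (2, 6):
        # torch 2.6 and 2.7 both get the bumped pin in this commit
        return ["vllm==0.8.4"]
    # torch 2.5.x: the CI matrix above stops installing the vllm extra
    # (assumption for this sketch; the real 2.5 branch of setup.py is not shown)
    return []


print(vllm_extra_for("2.7.0"))  # ['vllm==0.8.4']
print(vllm_extra_for("2.5.1"))  # []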

tests/e2e/multigpu/solo/test_grpo.py

Lines changed: 85 additions & 22 deletions

@@ -4,11 +4,14 @@

 import os
 import random
+import shutil
 import subprocess  # nosec B404
 import sys
+import tempfile
 import time
 from pathlib import Path

+import psutil
 import pytest
 import requests
 import yaml

@@ -21,8 +24,8 @@


 def start_vllm(
-    model: str, env: dict | None = None, wait: int | None = None, quiet=False, **kwargs
-) -> int:
+    model: str, env: dict, wait: int | None = None, quiet=False, **kwargs
+) -> subprocess.Popen:
     """
     helper function to start the VLLM server in the background, mostly for testing purposes
     """

@@ -46,10 +49,41 @@ def start_vllm(
     # print out the command to be executed
     print(" ".join(cmd))

+    vllm_logging_json = Path(tempfile.mkdtemp()) / "vllm_logging.json"
+    with open(vllm_logging_json, "w", encoding="utf-8") as temp_file:
+        temp_file.write(
+            """{
+                "formatters": {
+                    "json": {
+                        "class": "pythonjsonlogger.jsonlogger.JsonFormatter"
+                    }
+                },
+                "handlers": {
+                    "file": {
+                        "class": "logging.FileHandler",
+                        "formatter": "json",
+                        "level": "DEBUG",
+                        "filename": "/tmp/vllm.log",
+                        "mode": "a"
+                    }
+                },
+                "loggers": {
+                    "vllm": {
+                        "handlers": ["file"],
+                        "level": "DEBUG",
+                        "propagate": false
+                    }
+                },
+                "version": 1
+            }"""
+        )
+
+    cmd_env = env.copy()
+    cmd_env.update({"VLLM_LOGGING_CONFIG_PATH": vllm_logging_json})
     # start `trl vllm-serve` command in the background and capture the process id
     process = subprocess.Popen(  # pylint: disable=consider-using-with
         cmd,
-        env=env,
+        env=cmd_env,
         stdout=subprocess.DEVNULL if quiet else subprocess.PIPE,
         stderr=subprocess.DEVNULL if quiet else subprocess.PIPE,
     )  # nosec B603

@@ -58,32 +92,51 @@ def start_vllm(
     print(f"VLLM server process started (PID: {process.pid})")

     # wait until the http server is ready, even if it 404s, but timeout after 60 seconds
+    period_seconds = 5
     started = False
     if wait and host and port:
-        for _ in range(int(wait)):
+        for i in range(0, int(wait), period_seconds):
             try:
                 response = requests.get(f"http://{host}:{port}", timeout=1)
+                print(f"{i}: VLLM server (status: {response.status_code})")
                 if int(response.status_code) in [200, 404]:
                     started = True
                     break
-            except requests.exceptions.RequestException:
-                pass
+            except requests.exceptions.RequestException as exc:
+                print(f"{i}: VLLM server failed to start: {str(exc)}")

             # also check if the process.pid is still running
             if not process.poll() is None:
                 break

-            time.sleep(1)
+            time.sleep(period_seconds)

     if wait and not started:
         print(
             f"VLLM server process did not start within {wait} seconds. Please check your server logs."
         )
-        process.kill()
+        recursive_kill(process)
+        with open("/tmp/vllm.log", "r", encoding="utf-8") as log_file:
+            print(log_file.read())
+        shutil.rmtree("/tmp/vllm.log")
         raise RuntimeError(f"VLLM server process did not start within {wait} seconds.")

-    # return the process id
-    return process.pid
+    # return the process
+    return process
+
+
+def recursive_kill(process: subprocess.Popen):
+    """
+    Recursively kill a process and its children
+    """
+    process = psutil.Process(process.pid)
+    for child in psutil.Process(process.pid).children(recursive=True):
+        child.terminate()
+        child.kill()
+        os.kill(child.pid, 9)
+    process.terminate()
+    process.kill()
+    os.kill(process.pid, 9)


 class TestGRPO:

@@ -174,16 +227,17 @@ def test_llama_dora(self, temp_dir, num_gpus):

         current_env = os.environ.copy()
         env = {
-            "NCCL_P2P_LEVEL": "LOC",
+            "NCCL_P2P_LEVEL": "NVL",
             **current_env,
             "CUDA_VISIBLE_DEVICES": "1",
-            "VLLM_USE_V1": "0",
+            "VLLM_DISABLE_COMPILE_CACHE": "1",
+            # "VLLM_USE_V1": "0",
         }
-        vllm_process_id = start_vllm(
+        vllm_process = start_vllm(
             cfg.base_model,
             env=env,
             quiet=True,
-            wait=120,
+            wait=300,
             gpu_memory_utilization=0.15,
             max_model_len=cfg.vllm.max_model_len,
             enable_prefix_caching=cfg.vllm.enable_prefix_caching,

@@ -202,10 +256,14 @@ def test_llama_dora(self, temp_dir, num_gpus):
                     "--main-process-port",
                     f"{get_torch_dist_unique_port()}",
                 ],
-                env={"NCCL_P2P_LEVEL": "LOC", "NCCL_DEBUG": "INFO", **current_env},
+                env={
+                    "NCCL_P2P_LEVEL": "NVL",
+                    "NCCL_DEBUG": "INFO",
+                    **current_env,
+                },
             )
         finally:
-            os.kill(vllm_process_id, 9)
+            recursive_kill(vllm_process)

     @pytest.mark.parametrize(
         "num_gpus",

@@ -262,16 +320,17 @@ def test_llama_fft(self, temp_dir, num_gpus):

         current_env = os.environ.copy()
         env = {
-            "NCCL_P2P_LEVEL": "LOC",  # nccl can be brittle, assume P2P isn't reliable
+            "NCCL_P2P_LEVEL": "NVL",  # nccl can be brittle, assume P2P isn't reliable
             **current_env,
             "CUDA_VISIBLE_DEVICES": "1",
-            "VLLM_USE_V1": "0",
+            "VLLM_DISABLE_COMPILE_CACHE": "1",
+            # "VLLM_USE_V1": "0",
         }
-        vllm_process_id = start_vllm(
+        vllm_process = start_vllm(
             cfg.base_model,
             env=env,
             quiet=True,
-            wait=120,
+            wait=300,
             gpu_memory_utilization=0.15,
             max_model_len=cfg.vllm.max_model_len,
             enable_prefix_caching=cfg.vllm.enable_prefix_caching,

@@ -290,7 +349,11 @@ def test_llama_fft(self, temp_dir, num_gpus):
                     "--main-process-port",
                     f"{get_torch_dist_unique_port()}",
                 ],
-                env={"NCCL_P2P_LEVEL": "LOC", "NCCL_DEBUG": "INFO", **current_env},
+                env={
+                    "NCCL_P2P_LEVEL": "NVL",
+                    "NCCL_DEBUG": "INFO",
+                    **current_env,
+                },
             )
         finally:
-            os.kill(vllm_process_id, 9)
+            recursive_kill(vllm_process)
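Taken together, the tests now hold on to the Popen handle that start_vllm returns and tear the whole process tree down in a finally block via recursive_kill. A hedged usage sketch of that pattern (model id, memory setting, and max_model_len below are placeholders, not values from the diff):

import os

# Sketch only: mirrors the try/finally pattern used by test_llama_dora / test_llama_fft above.
vllm_process = start_vllm(
    "my-org/my-base-model",  # placeholder model id
    env={**os.environ, "NCCL_P2P_LEVEL": "NVL", "VLLM_DISABLE_COMPILE_CACHE": "1"},
    quiet=True,
    wait=300,                # 5-minute startup budget, matching the diff
    gpu_memory_utilization=0.15,
    max_model_len=2048,      # placeholder
)
try:
    # ... launch the accelerate/axolotl GRPO training run against the vLLM server ...
    pass
finally:
    # vLLM spawns worker processes, so kill the server and every child it forked.
    recursive_kill(vllm_process)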
