4
4
5
5
import os
6
6
import random
7
+ import shutil
7
8
import subprocess # nosec B404
8
9
import sys
10
+ import tempfile
9
11
import time
10
12
from pathlib import Path
11
13
14
+ import psutil
12
15
import pytest
13
16
import requests
14
17
import yaml
21
24
22
25
23
26
def start_vllm (
24
- model : str , env : dict | None = None , wait : int | None = None , quiet = False , ** kwargs
25
- ) -> int :
27
+ model : str , env : dict , wait : int | None = None , quiet = False , ** kwargs
28
+ ) -> subprocess . Popen :
26
29
"""
27
30
helper function to start the VLLM server in the background, mostly for testing purposes
28
31
"""
@@ -46,10 +49,41 @@ def start_vllm(
46
49
# print out the command to be executed
47
50
print (" " .join (cmd ))
48
51
52
+ vllm_logging_json = Path (tempfile .mkdtemp ()) / "vllm_logging.json"
53
+ with open (vllm_logging_json , "w" , encoding = "utf-8" ) as temp_file :
54
+ temp_file .write (
55
+ """{
56
+ "formatters": {
57
+ "json": {
58
+ "class": "pythonjsonlogger.jsonlogger.JsonFormatter"
59
+ }
60
+ },
61
+ "handlers": {
62
+ "file": {
63
+ "class": "logging.FileHandler",
64
+ "formatter": "json",
65
+ "level": "DEBUG",
66
+ "filename": "/tmp/vllm.log",
67
+ "mode": "a"
68
+ }
69
+ },
70
+ "loggers": {
71
+ "vllm": {
72
+ "handlers": ["file"],
73
+ "level": "DEBUG",
74
+ "propagate": false
75
+ }
76
+ },
77
+ "version": 1
78
+ }"""
79
+ )
80
+
81
+ cmd_env = env .copy ()
82
+ cmd_env .update ({"VLLM_LOGGING_CONFIG_PATH" : vllm_logging_json })
49
83
# start `trl vllm-serve` command in the background and capture the process id
50
84
process = subprocess .Popen ( # pylint: disable=consider-using-with
51
85
cmd ,
52
- env = env ,
86
+ env = cmd_env ,
53
87
stdout = subprocess .DEVNULL if quiet else subprocess .PIPE ,
54
88
stderr = subprocess .DEVNULL if quiet else subprocess .PIPE ,
55
89
) # nosec B603
@@ -58,32 +92,51 @@ def start_vllm(
58
92
print (f"VLLM server process started (PID: { process .pid } )" )
59
93
60
94
# wait until the http server is ready, even if it 404s, but timeout after 60 seconds
95
+ period_seconds = 5
61
96
started = False
62
97
if wait and host and port :
63
- for _ in range (int (wait )):
98
+ for i in range (0 , int (wait ), period_seconds ):
64
99
try :
65
100
response = requests .get (f"http://{ host } :{ port } " , timeout = 1 )
101
+ print (f"{ i } : VLLM server (status: { response .status_code } )" )
66
102
if int (response .status_code ) in [200 , 404 ]:
67
103
started = True
68
104
break
69
- except requests .exceptions .RequestException :
70
- pass
105
+ except requests .exceptions .RequestException as exc :
106
+ print ( f" { i } : VLLM server failed to start: { str ( exc ) } " )
71
107
72
108
# also check if the process.pid is still running
73
109
if not process .poll () is None :
74
110
break
75
111
76
- time .sleep (1 )
112
+ time .sleep (period_seconds )
77
113
78
114
if wait and not started :
79
115
print (
80
116
f"VLLM server process did not start within { wait } seconds. Please check your server logs."
81
117
)
82
- process .kill ()
118
+ recursive_kill (process )
119
+ with open ("/tmp/vllm.log" , "r" , encoding = "utf-8" ) as log_file :
120
+ print (log_file .read ())
121
+ shutil .rmtree ("/tmp/vllm.log" )
83
122
raise RuntimeError (f"VLLM server process did not start within { wait } seconds." )
84
123
85
- # return the process id
86
- return process .pid
124
+ # return the process
125
+ return process
126
+
127
+
128
def recursive_kill(process: subprocess.Popen, grace_seconds: float = 3.0):
    """
    Recursively kill a process and its children

    Every descendant of `process` and then `process` itself is sent a
    terminate request. Anything still alive after `grace_seconds` is killed
    forcefully. Processes that disappear while this runs are ignored, so the
    helper is safe to call from `finally` blocks and error paths.

    Args:
        process: handle of the process tree root, as returned by
            `subprocess.Popen`.
        grace_seconds: how long to wait for a graceful exit before escalating
            to a forced kill, and how long to wait when reaping the root.
    """
    # If the root has already exited there is nothing to enumerate: its pid is
    # gone (or may have been reused by an unrelated process), so looking it up
    # through psutil would either raise or target the wrong process.
    if process.poll() is not None:
        return

    try:
        root = psutil.Process(process.pid)
        # children first, root last, mirroring the original kill order
        targets = root.children(recursive=True) + [root]
    except psutil.NoSuchProcess:
        # the root exited between the poll() above and the lookup
        targets = []

    for target in targets:
        try:
            target.terminate()
        except psutil.NoSuchProcess:
            # already gone, which is the outcome we want
            pass

    # give the tree a chance to shut down cleanly, then force the stragglers
    _, survivors = psutil.wait_procs(targets, timeout=grace_seconds)
    for survivor in survivors:
        try:
            survivor.kill()
        except psutil.NoSuchProcess:
            pass

    # reap the root so it does not linger as a zombie of the test process
    try:
        process.wait(timeout=grace_seconds)
    except subprocess.TimeoutExpired:
        # best effort: the process ignored even the forced kill within the
        # timeout; do not turn cleanup into a second failure
        pass
87
140
88
141
89
142
class TestGRPO :
@@ -174,16 +227,17 @@ def test_llama_dora(self, temp_dir, num_gpus):
174
227
175
228
current_env = os .environ .copy ()
176
229
env = {
177
- "NCCL_P2P_LEVEL" : "LOC " ,
230
+ "NCCL_P2P_LEVEL" : "NVL " ,
178
231
** current_env ,
179
232
"CUDA_VISIBLE_DEVICES" : "1" ,
180
- "VLLM_USE_V1" : "0" ,
233
+ "VLLM_DISABLE_COMPILE_CACHE" : "1" ,
234
+ # "VLLM_USE_V1": "0",
181
235
}
182
- vllm_process_id = start_vllm (
236
+ vllm_process = start_vllm (
183
237
cfg .base_model ,
184
238
env = env ,
185
239
quiet = True ,
186
- wait = 120 ,
240
+ wait = 300 ,
187
241
gpu_memory_utilization = 0.15 ,
188
242
max_model_len = cfg .vllm .max_model_len ,
189
243
enable_prefix_caching = cfg .vllm .enable_prefix_caching ,
@@ -202,10 +256,14 @@ def test_llama_dora(self, temp_dir, num_gpus):
202
256
"--main-process-port" ,
203
257
f"{ get_torch_dist_unique_port ()} " ,
204
258
],
205
- env = {"NCCL_P2P_LEVEL" : "LOC" , "NCCL_DEBUG" : "INFO" , ** current_env },
259
+ env = {
260
+ "NCCL_P2P_LEVEL" : "NVL" ,
261
+ "NCCL_DEBUG" : "INFO" ,
262
+ ** current_env ,
263
+ },
206
264
)
207
265
finally :
208
- os . kill ( vllm_process_id , 9 )
266
+ recursive_kill ( vllm_process )
209
267
210
268
@pytest .mark .parametrize (
211
269
"num_gpus" ,
@@ -262,16 +320,17 @@ def test_llama_fft(self, temp_dir, num_gpus):
262
320
263
321
current_env = os .environ .copy ()
264
322
env = {
265
- "NCCL_P2P_LEVEL" : "LOC " , # nccl can be brittle, assume P2P isn't reliable
323
+ "NCCL_P2P_LEVEL" : "NVL " , # nccl can be brittle, assume P2P isn't reliable
266
324
** current_env ,
267
325
"CUDA_VISIBLE_DEVICES" : "1" ,
268
- "VLLM_USE_V1" : "0" ,
326
+ "VLLM_DISABLE_COMPILE_CACHE" : "1" ,
327
+ # "VLLM_USE_V1": "0",
269
328
}
270
- vllm_process_id = start_vllm (
329
+ vllm_process = start_vllm (
271
330
cfg .base_model ,
272
331
env = env ,
273
332
quiet = True ,
274
- wait = 120 ,
333
+ wait = 300 ,
275
334
gpu_memory_utilization = 0.15 ,
276
335
max_model_len = cfg .vllm .max_model_len ,
277
336
enable_prefix_caching = cfg .vllm .enable_prefix_caching ,
@@ -290,7 +349,11 @@ def test_llama_fft(self, temp_dir, num_gpus):
290
349
"--main-process-port" ,
291
350
f"{ get_torch_dist_unique_port ()} " ,
292
351
],
293
- env = {"NCCL_P2P_LEVEL" : "LOC" , "NCCL_DEBUG" : "INFO" , ** current_env },
352
+ env = {
353
+ "NCCL_P2P_LEVEL" : "NVL" ,
354
+ "NCCL_DEBUG" : "INFO" ,
355
+ ** current_env ,
356
+ },
294
357
)
295
358
finally :
296
- os . kill ( vllm_process_id , 9 )
359
+ recursive_kill ( vllm_process )
0 commit comments