Skip to content

Commit 0938de2

Browse files
committed
Fix DeepEP installation; use overlay0 interface to get container rank; add Ray dashboard deps for data parallel; remove OOM parallel configs
1 parent a4d8284 commit 0938de2

File tree

1 file changed

+67
-51
lines changed

1 file changed

+67
-51
lines changed

k2-inference/main.py

Lines changed: 67 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import os
2+
import re
23
import subprocess
34
import time
45

@@ -12,7 +13,8 @@
1213
.apt_install("libibverbs-dev", "libibverbs1")
1314
.run_commands(
1415
"uv pip install --system -U uv",
15-
"uv pip install --system blobfile==3.0.0 requests==2.32.4",
16+
"uv pip install --system blobfile==3.0.0 requests==2.32.4 psutil",
17+
"uv pip install --system ray[default]",
1618
# using nightly until they cut a new release (current stable is v0.9.2)
1719
# we need this vllm commit to use pipeline parallelism with kimi:
1820
# https://github.com/vllm-project/vllm/commit/ad6c2e1a0b56c29065c7d70ff2e736e4f2fb03af
@@ -23,16 +25,22 @@
2325
# and when 2+ processes participate in collective ops. upgrading to 2.27+ fixes this
2426
"uv pip install --system -U nvidia-nccl-cu12==2.27.6",
2527
)
26-
.apt_install("git", "build-essential", "g++")
28+
.apt_install("git", "build-essential", "g++", "wget")
2729
.run_commands(
30+
"uv pip install --system cuda-bindings",
2831
# recursive bc DeepGEMM vendors CUTLASS for the build
29-
"git clone --recursive https://github.com/deepseek-ai/DeepGEMM.git",
32+
"git clone https://github.com/deepseek-ai/DeepGEMM.git",
33+
# latest commit on main broke authless recursive clone, thus:
34+
"cd DeepGEMM && git checkout 03d0be3d2d03b6eed3c99d683c0620949a13a826",
35+
"cd DeepGEMM && git submodule update --init --recursive",
3036
"uv pip install --system ./DeepGEMM",
3137
)
3238
.run_commands(
3339
"uv pip install --system nvidia-nvshmem-cu12==3.3.9",
3440
"git clone https://github.com/deepseek-ai/DeepEP.git",
35-
"NVSHMEM_DIR=$(python -c 'import nvidia.nvshmem; import os; print(os.path.dirname(nvidia.nvshmem.__file__))' 2>/dev/null) CXX=g++ uv pip install --system ./DeepEP --no-build-isolation",
41+
# nvidia-nvshmem-cu12 ships with versioned binaries, but the DeepEP build process expects unversioned. sigh...
42+
"cd $(python -c 'import nvidia.nvshmem; import os; print(nvidia.nvshmem.__path__[0])') && cp lib/libnvshmem_host.so.3 lib/libnvshmem_host.so",
43+
"NVSHMEM_DIR=$(python -c 'import nvidia.nvshmem; import os; print(nvidia.nvshmem.__path__[0])') CXX=g++ uv pip install --system ./DeepEP --no-build-isolation",
3644
)
3745
.env({"RAY_DISABLE_DOCKER_CPU_WARNING": "1", "VLLM_USE_DEEPGEMM": "1"})
3846
)
@@ -42,7 +50,8 @@
4250
# Volume for Hugging Face cache
4351
hf_cache_volume = modal.Volume.from_name("big-model-hfcache")
4452
vllm_cache_volume = modal.Volume.from_name(
45-
"k2-multinode-vllmcache", create_if_missing=True, version=2
53+
"k2-multinode-vllmcache",
54+
create_if_missing=True,
4655
)
4756

4857
# Ray configuration
@@ -53,6 +62,7 @@
5362
MODEL = "moonshotai/Kimi-K2-Instruct"
5463

5564
with image.imports():
65+
import psutil
5666
import requests
5767

5868

@@ -65,16 +75,18 @@ class K2Inference:
6575
pp_size: int
6676
dp_size: int
6777
max_seqs: int
78+
nodes: int
6879
max_model_len: int = 128000
6980
enable_expert_parallel: bool = False
7081

7182
@modal.enter()
7283
def setup(self):
73-
container_rank = _spawn_ray_nodes()
84+
container_rank = _spawn_ray_nodes(self.nodes)
7485
vllm_cmd = _build_vllm_cmd(
7586
self.tp_size,
7687
self.pp_size,
7788
self.dp_size,
89+
self.nodes,
7890
self.max_seqs,
7991
self.max_model_len,
8092
self.enable_expert_parallel,
@@ -133,13 +145,28 @@ def cleanup(self):
133145
self.flash_handle.close()
134146

135147

136-
def _spawn_ray_nodes():
148+
def get_overlay0_address(interfaces=None):
    """Return ``(ipv4_address, container_rank)`` for the ``overlay0`` interface.

    The cluster overlay network appears to assign each container an address
    of the form ``10.100.0.{rank + 1}`` (see the ``container_v4_ips`` list
    built by the caller), so the rank is recovered as ``last_octet - 1``.
    # NOTE(review): assumes that addressing scheme holds — confirm with the
    # overlay-network configuration.

    Args:
        interfaces: optional mapping of interface name -> list of address
            entries (same shape as ``psutil.net_if_addrs()``). Defaults to
            querying the live interface table via psutil; injectable for
            testing.

    Returns:
        ``(address, rank)`` on success, or ``(None, None)`` when overlay0 is
        absent, has no IPv4 address, or the lookup fails.
    """
    import socket

    try:
        if interfaces is None:
            # Deferred import: psutil is only installed inside the Modal image.
            import psutil

            interfaces = psutil.net_if_addrs()
        for addr in interfaces.get("overlay0", []):
            # socket.AF_INET instead of the magic number 2 (IPv4 family).
            if addr.family == socket.AF_INET:
                last_octet = int(addr.address.split(".")[-1])
                return addr.address, last_octet - 1
        return None, None
    except Exception as e:
        # Deliberate best-effort: the caller asserts on the result, so log
        # and fall through to the (None, None) failure sentinel rather than
        # crashing here.
        print(f"Error: {e}")
        return None, None
162+
163+
164+
def _spawn_ray_nodes(num_nodes):
137165
# Get cluster information
138-
cluster_info = modal.experimental.get_cluster_info()
139-
container_rank = cluster_info.rank
140-
container_v4_ips = [
141-
f"10.100.0.{rank + 1}" for rank in range(len(cluster_info.container_ips))
142-
]
166+
ipv4_addr, container_rank = get_overlay0_address()
167+
assert ipv4_addr is not None and container_rank is not None
168+
169+
container_v4_ips = [f"10.100.0.{rank + 1}" for rank in range(num_nodes)]
143170
main_addr_v4 = container_v4_ips[0]
144171
this_v4 = container_v4_ips[container_rank]
145172

@@ -159,6 +186,7 @@ def _spawn_ray_nodes():
159186
f"--port={RAY_PORT}", # 6379
160187
"--dashboard-host=0.0.0.0",
161188
f"--dashboard-port={RAY_DASHBOARD_PORT}", # 8265
189+
"--include-dashboard=True",
162190
"--block",
163191
]
164192
else:
@@ -197,6 +225,7 @@ def _build_vllm_cmd(
197225
tp_size: int,
198226
pp_size: int,
199227
dp_size: int,
228+
num_nodes: int,
200229
max_seqs: int,
201230
max_model_len: int,
202231
enable_expert_parallel: bool,
@@ -206,35 +235,43 @@ def _build_vllm_cmd(
206235
vllm_cmd = [
207236
"vllm",
208237
"serve",
209-
"--model",
210238
MODEL,
211239
"--download-dir",
212240
"/root/.cache/huggingface",
213-
"--model-name",
241+
"--served-model-name",
214242
"kimi-k2",
243+
"--enable-auto-tool-choice",
244+
"--tool-call-parser",
245+
"kimi_k2",
215246
"--trust-remote-code",
216247
"--host",
217248
"0.0.0.0",
218249
"--port",
219250
str(VLLM_PORT),
251+
"--max-model-len",
252+
str(max_model_len),
253+
"--max-num-seqs",
254+
str(max_seqs),
255+
"--gpu-memory-utilization",
256+
"0.95",
220257
"--distributed-executor-backend",
221258
"ray",
222259
"--tensor-parallel-size",
223260
str(tp_size),
224261
"--pipeline-parallel-size",
225262
str(pp_size),
226-
"--max-model-len",
227-
str(max_model_len),
228-
"--max-num-seqs",
229-
str(max_seqs),
230263
]
231264
if dp_size > 1:
265+
dp_size_local = dp_size // num_nodes if dp_size >= num_nodes else 1
266+
232267
vllm_cmd.extend(
233268
[
234269
"--data-parallel-backend",
235270
"ray",
236271
"--data-parallel-size",
237272
str(dp_size),
273+
"--data-parallel-size-local",
274+
str(dp_size_local),
238275
]
239276
)
240277
if enable_expert_parallel:
@@ -244,29 +281,6 @@ def _build_vllm_cmd(
244281
return vllm_cmd
245282

246283

247-
# probably not a great idea due to potential expert load imbalance issues,
248-
# but worth a test
249-
@app.cls(
250-
image=image,
251-
gpu="H100:8",
252-
volumes={
253-
"/root/.cache/huggingface": hf_cache_volume,
254-
"/root/.cache/vllm": vllm_cache_volume,
255-
},
256-
timeout=60 * 60 * 1,
257-
experimental_options={"flash": "us-east"},
258-
)
259-
@modal.experimental.clustered(size=2, rdma=True)
260-
class K2Tp8Ep2(K2Inference):
261-
# 2x8H100
262-
# tp=8,pp=1,ep=2,dp=1
263-
tp_size = 8
264-
pp_size = 1
265-
dp_size = 1
266-
max_seqs = 256
267-
enable_expert_parallel = True
268-
269-
270284
@app.cls(
271285
image=image,
272286
gpu="H100:8",
@@ -278,14 +292,15 @@ class K2Tp8Ep2(K2Inference):
278292
experimental_options={"flash": "us-east"},
279293
)
280294
@modal.experimental.clustered(size=4, rdma=True)
281-
class K2Tp8Ep4(K2Inference):
282-
# low-latency
295+
class K2Tp8Dp2Ep2(K2Inference):
283296
# 4x8H100
284-
# tp=8,pp=1,ep=4,dp=1
297+
# tp=8,pp=1,ep=2,dp=2
285298
tp_size = 8
286299
pp_size = 1
287-
dp_size = 1
288-
max_seqs = 256
300+
dp_size = 2
301+
nodes = 4
302+
max_seqs = 2
303+
max_model_len = 64000
289304
enable_expert_parallel = True
290305

291306

@@ -300,12 +315,13 @@ class K2Tp8Ep4(K2Inference):
300315
experimental_options={"flash": "us-east"},
301316
)
302317
@modal.experimental.clustered(size=4, rdma=True)
303-
class K2Tp8Ep2Dp2(K2Inference):
304-
# high throughput
318+
class K2Tp8Pp2Ep2(K2Inference):
305319
# 4x8H100
306-
# tp=8,pp=1,ep=2,dp=2
320+
# tp=8,pp=2,ep=2,dp=1
307321
tp_size = 8
308-
pp_size = 1
309-
dp_size = 2
322+
pp_size = 2
323+
dp_size = 1
324+
nodes = 4
310325
max_seqs = 256
326+
max_model_len = 64000
311327
enable_expert_parallel = True

0 commit comments

Comments
 (0)