Skip to content

Commit e6da57e

Browse files
committed
Fix DeepEP installation; use overlay0 interface to get container rank; add Ray dashboard deps for data parallel; remove OOM parallel configs
1 parent e397255 commit e6da57e

File tree

1 file changed

+77
-53
lines changed

1 file changed

+77
-53
lines changed

k2-inference/main.py

Lines changed: 77 additions & 53 deletions
Original file line number · Diff line number · Diff line change
@@ -1,4 +1,5 @@
11
import os
2+
import re
23
import subprocess
34
import time
45

@@ -12,7 +13,8 @@
1213
.apt_install("libibverbs-dev", "libibverbs1")
1314
.run_commands(
1415
"uv pip install --system -U uv",
15-
"uv pip install --system blobfile==3.0.0 requests==2.32.4",
16+
"uv pip install --system blobfile==3.0.0 requests==2.32.4 psutil",
17+
"uv pip install --system ray[default]",
1618
# using nightly until they cut a new release (current stable is v0.9.2)
1719
# we need this vllm commit to use pipeline parallelism with kimi:
1820
# https://github.com/vllm-project/vllm/commit/ad6c2e1a0b56c29065c7d70ff2e736e4f2fb03af
@@ -23,16 +25,22 @@
2325
# and when 2+ processes participate in collective ops. upgrading to 2.27+ fixes this
2426
"uv pip install --system -U nvidia-nccl-cu12==2.27.6",
2527
)
26-
.apt_install("git", "build-essential", "g++")
28+
.apt_install("git", "build-essential", "g++", "wget")
2729
.run_commands(
30+
"uv pip install --system cuda-bindings",
2831
# recursive bc DeepGEMM vendors CUTLASS for the build
29-
"git clone --recursive https://github.com/deepseek-ai/DeepGEMM.git",
32+
"git clone https://github.com/deepseek-ai/DeepGEMM.git",
33+
# latest commit on main broke authless recursive clone, thus:
34+
"cd DeepGEMM && git checkout 03d0be3d2d03b6eed3c99d683c0620949a13a826",
35+
"cd DeepGEMM && git submodule update --init --recursive",
3036
"uv pip install --system ./DeepGEMM",
3137
)
3238
.run_commands(
3339
"uv pip install --system nvidia-nvshmem-cu12==3.3.9",
3440
"git clone https://github.com/deepseek-ai/DeepEP.git",
35-
"NVSHMEM_DIR=$(python -c 'import nvidia.nvshmem; import os; print(os.path.dirname(nvidia.nvshmem.__file__))' 2>/dev/null) CXX=g++ uv pip install --system ./DeepEP --no-build-isolation",
41+
# nvidia-nvshmem-cu12 ships with versioned binaries, but the DeepEP build process expects unversioned. sigh...
42+
"cd $(python -c 'import nvidia.nvshmem; import os; print(nvidia.nvshmem.__path__[0])') && cp lib/libnvshmem_host.so.3 lib/libnvshmem_host.so",
43+
"NVSHMEM_DIR=$(python -c 'import nvidia.nvshmem; import os; print(nvidia.nvshmem.__path__[0])') CXX=g++ uv pip install --system ./DeepEP --no-build-isolation",
3644
)
3745
.env({"RAY_DISABLE_DOCKER_CPU_WARNING": "1", "VLLM_USE_DEEPGEMM": "1"})
3846
)
@@ -42,7 +50,8 @@
4250
# Volume for Hugging Face cache
4351
hf_cache_volume = modal.Volume.from_name("big-model-hfcache")
4452
vllm_cache_volume = modal.Volume.from_name(
45-
"k2-multinode-vllmcache", create_if_missing=True, version=2
53+
"k2-multinode-vllmcache",
54+
create_if_missing=True,
4655
)
4756

4857
# Ray configuration
@@ -53,6 +62,7 @@
5362
MODEL = "moonshotai/Kimi-K2-Instruct"
5463

5564
with image.imports():
65+
import psutil
5666
import requests
5767

5868

@@ -65,16 +75,18 @@ class K2Inference:
6575
pp_size: int
6676
dp_size: int
6777
max_seqs: int
78+
nodes: int
6879
max_model_len: int = 128000
6980
enable_expert_parallel: bool = False
7081

7182
@modal.enter()
7283
def setup(self):
73-
container_rank = _spawn_ray_nodes()
84+
container_rank = _spawn_ray_nodes(self.nodes)
7485
vllm_cmd = _build_vllm_cmd(
7586
self.tp_size,
7687
self.pp_size,
7788
self.dp_size,
89+
self.nodes,
7890
self.max_seqs,
7991
self.max_model_len,
8092
self.enable_expert_parallel,
@@ -102,8 +114,13 @@ def server(self):
102114

103115
@modal.exit()
104116
def cleanup(self):
105-
cluster_info = modal.experimental.get_cluster_info()
106-
if cluster_info.rank == 0:
117+
_, container_rank = get_overlay0_address()
118+
if container_rank is None:
119+
print(
120+
"WARNING: Failed to infer container rank, exiting early. Any processing requests will be cancelled."
121+
)
122+
return
123+
if container_rank == 0:
107124
self.flash_handle.stop()
108125

109126
deadline = time.time() + 60 # 1 minute deadline
@@ -133,13 +150,31 @@ def cleanup(self):
133150
self.flash_handle.close()
134151

135152

136-
def _spawn_ray_nodes():
153+
def get_overlay0_address():
    """Return ``(ipv4_address, container_rank)`` for the ``overlay0`` interface.

    Containers on the cluster overlay network get addresses of the form
    ``10.100.0.<rank + 1>`` (see the address list built in ``_spawn_ray_nodes``),
    so the container rank can be recovered from the last octet.

    Returns:
        tuple: ``(address, rank)`` on success, or ``(None, None)`` if the
        ``overlay0`` interface is missing, has no IPv4 address, or the lookup
        fails for any reason (best-effort: callers decide how to handle it).
    """
    import socket  # local import: only needed for the AF_INET constant

    try:
        interfaces = psutil.net_if_addrs()
        if "overlay0" in interfaces:
            for addr in interfaces["overlay0"]:
                # socket.AF_INET instead of the magic number 2: psutil reports
                # the family as a socket.AddressFamily, and the named constant
                # is self-documenting.
                if addr.family == socket.AF_INET:
                    ip_parts = addr.address.split(".")
                    last_octet = int(ip_parts[-1])
                    # Addresses are 10.100.0.<rank + 1>, hence rank = octet - 1.
                    return addr.address, last_octet - 1
        return None, None
    except Exception as e:
        # Deliberately broad: rank inference is best-effort and must never
        # crash setup/teardown; log and fall through to the sentinel.
        print(f"Error: {e}")
        return None, None
167+
168+
169+
def _spawn_ray_nodes(num_nodes):
137170
# Get cluster information
138-
cluster_info = modal.experimental.get_cluster_info()
139-
container_rank = cluster_info.rank
140-
container_v4_ips = [
141-
f"10.100.0.{rank + 1}" for rank in range(len(cluster_info.container_ips))
142-
]
171+
ipv4_addr, container_rank = get_overlay0_address()
172+
if ipv4_addr is None:
173+
raise RuntimeError(
174+
"Could not infer container rank, because `overlay0` network interface is missing"
175+
)
176+
177+
container_v4_ips = [f"10.100.0.{rank + 1}" for rank in range(num_nodes)]
143178
main_addr_v4 = container_v4_ips[0]
144179
this_v4 = container_v4_ips[container_rank]
145180

@@ -159,6 +194,7 @@ def _spawn_ray_nodes():
159194
f"--port={RAY_PORT}", # 6379
160195
"--dashboard-host=0.0.0.0",
161196
f"--dashboard-port={RAY_DASHBOARD_PORT}", # 8265
197+
"--include-dashboard=True",
162198
"--block",
163199
]
164200
else:
@@ -197,6 +233,7 @@ def _build_vllm_cmd(
197233
tp_size: int,
198234
pp_size: int,
199235
dp_size: int,
236+
num_nodes: int,
200237
max_seqs: int,
201238
max_model_len: int,
202239
enable_expert_parallel: bool,
@@ -209,31 +246,40 @@ def _build_vllm_cmd(
209246
MODEL,
210247
"--download-dir",
211248
"/root/.cache/huggingface",
212-
"--model-name",
249+
"--served-model-name",
213250
"kimi-k2",
251+
"--enable-auto-tool-choice",
252+
"--tool-call-parser",
253+
"kimi_k2",
214254
"--trust-remote-code",
215255
"--host",
216256
"0.0.0.0",
217257
"--port",
218258
str(VLLM_PORT),
259+
"--max-model-len",
260+
str(max_model_len),
261+
"--max-num-seqs",
262+
str(max_seqs),
263+
"--gpu-memory-utilization",
264+
"0.95",
219265
"--distributed-executor-backend",
220266
"ray",
221267
"--tensor-parallel-size",
222268
str(tp_size),
223269
"--pipeline-parallel-size",
224270
str(pp_size),
225-
"--max-model-len",
226-
str(max_model_len),
227-
"--max-num-seqs",
228-
str(max_seqs),
229271
]
230272
if dp_size > 1:
273+
dp_size_local = dp_size // num_nodes if dp_size >= num_nodes else 1
274+
231275
vllm_cmd.extend(
232276
[
233277
"--data-parallel-backend",
234278
"ray",
235279
"--data-parallel-size",
236280
str(dp_size),
281+
"--data-parallel-size-local",
282+
str(dp_size_local),
237283
]
238284
)
239285
if enable_expert_parallel:
@@ -243,29 +289,6 @@ def _build_vllm_cmd(
243289
return vllm_cmd
244290

245291

246-
# NOTE(review): every line below carries a "-" diff marker — this commit
# deletes the K2Tp8Ep2 deployment config ("remove OOM parallel configs"
# per the commit message, i.e. this 2-node layout presumably ran out of
# memory — confirm).
# probably not a great idea due to potential expert load imbalance issues,
247-
# but worth a test
248-
@app.cls(
249-
image=image,
250-
gpu="H100:8",
251-
volumes={
252-
"/root/.cache/huggingface": hf_cache_volume,
253-
"/root/.cache/vllm": vllm_cache_volume,
254-
},
255-
timeout=60 * 60 * 1,
256-
experimental_options={"flash": "us-east"},
257-
)
258-
@modal.experimental.clustered(size=2, rdma=True)
259-
class K2Tp8Ep2(K2Inference):
260-
# 2x8H100
261-
# tp=8,pp=1,ep=2,dp=1
262-
# Parallelism knobs consumed by K2Inference.setup()/_build_vllm_cmd():
# tensor parallel across the 8 GPUs of a node, no pipeline or data
# parallelism, expert parallelism enabled across the 2 nodes.
tp_size = 8
263-
pp_size = 1
264-
dp_size = 1
265-
max_seqs = 256
266-
enable_expert_parallel = True
267-
268-
269292
@app.cls(
270293
image=image,
271294
gpu="H100:8",
@@ -277,14 +300,15 @@ class K2Tp8Ep2(K2Inference):
277300
experimental_options={"flash": "us-east"},
278301
)
279302
@modal.experimental.clustered(size=4, rdma=True)
280-
class K2Tp8Ep4(K2Inference):
281-
# low-latency
303+
class K2Tp8Dp2Ep2(K2Inference):
282304
# 4x8H100
283-
# tp=8,pp=1,ep=4,dp=1
305+
# tp=8,pp=1,ep=2,dp=2
284306
tp_size = 8
285307
pp_size = 1
286-
dp_size = 1
287-
max_seqs = 256
308+
dp_size = 2
309+
nodes = 4
310+
max_seqs = 4
311+
max_model_len = 48000
288312
enable_expert_parallel = True
289313

290314

@@ -299,13 +323,13 @@ class K2Tp8Ep4(K2Inference):
299323
experimental_options={"flash": "us-east"},
300324
)
301325
@modal.experimental.clustered(size=4, rdma=True)
302-
class K2Tp8Ep2Dp2(K2Inference):
303-
# high throughput
326+
class K2Tp8Pp2Ep2(K2Inference):
304327
# 4x8H100
305-
# tp=8,pp=1,ep=2,dp=2
328+
# tp=8,pp=2,ep=2,dp=1
306329
tp_size = 8
307-
pp_size = 1
308-
dp_size = 2
330+
pp_size = 2
331+
dp_size = 1
332+
nodes = 4
309333
max_seqs = 256
334+
max_model_len = 64000
310335
enable_expert_parallel = True
311-
>>>>>>> Conflict 2 of 2 ends

0 commit comments

Comments
 (0)