
Commit a63a1ac

[TRTLLM-6444] Add some UCX troubleshooting docs and print UCX-related logs (NVIDIA#6085)
Signed-off-by: Lizhi Zhou <[email protected]>
1 parent 428e340 commit a63a1ac

File tree

2 files changed: +40 −19 lines


docs/source/advanced/disaggregated-service.md

Lines changed: 14 additions & 0 deletions
```diff
@@ -32,6 +32,10 @@ TRT-LLM uses some environment variables to control the behavior of disaggregated
 
 * `TRTLLM_KVCACHE_SEND_MAX_CONCURRENCY_NUM`: The maximum number of concurrent KV cache sends. The default value is `4`. This environment variable only takes effect when `TRTLLM_KVCACHE_TRANSFER_BUFFER_SIZE` is greater than 0.
 
+There are some other useful environment variables that may help when encountering failures or performance issues.
+
+* `NCCL_GRAPH_MIXING_SUPPORT`: With the default value `1`, the CUDA driver may create too many CUDA streams while working with one CUDA graph, leading to a performance drop. Setting it to `0` reduces the number of CUDA streams, but make sure there are no other NCCL ops outside the one CUDA graph, otherwise it is unsafe.
+
 ## Troubleshooting and FAQ
 
 ### General FAQs
```
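Since these are plain environment variables, they only take effect if they are present in the environment of the serving process. Below is a minimal launcher sketch applying the `NCCL_GRAPH_MIXING_SUPPORT=0` suggestion from the added doc text; the launch command and model path are illustrative placeholders, not part of this commit.

```python
import os
import subprocess

# Sketch only: the variable must be set before the worker initializes NCCL,
# and NCCL_GRAPH_MIXING_SUPPORT=0 is only safe when no NCCL ops run outside
# the single CUDA graph (see the doc text above).
env = dict(os.environ)
env["NCCL_GRAPH_MIXING_SUPPORT"] = "0"

cmd = ["trtllm-serve", "path/to/model"]  # placeholder invocation
subprocess.run(cmd, env=env, check=True)
```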
```diff
@@ -80,3 +84,13 @@ A. Yes, TRT-LLM supports using GPU direct RDMA for inter-node KV cache transfer.
 
 *Q. What causes the substantial bandwidth fluctuations in kvCache transfers, especially during the first few requests following service initialization?*
 A. The communication for kvCache transfer between executors is established dynamically. The connection establishment process incurs significant overhead, which explains the apparently lower kvCache transfer bandwidth observed during the initial requests after service startup. This lower bandwidth reflects the inclusion of connection establishment overhead. When conducting benchmarks, it is recommended to perform a warm-up phase to ensure accurate performance measurements.
+
+*Q. When my servers are running in different NVLink domains, some servers hang or show lower performance. How can I fix that?*
+
+A. The NVLink domain can be found with `nvidia-smi -q` in the `Fabric.ClusterUUID` field. A few UCX environment variables can be adjusted when your servers are in different NVLink domains:
+
+* `UCX_CUDA_IPC_ENABLE_MNNVL`: Set to `n`. This also reduces UCX timeout error messages such as `UCX ERROR cuMemImportFromShareableHandle failed: invalid resource handle`, although these errors do not necessarily cause trtllm-serve to fail.
+
+* `UCX_NET_DEVICES`: Check that this is set correctly, or unset the variable to allow UCX to use all possible devices.
+
+* `UCX_RNDV_SCHEME`: Set to `get_zcopy` or `put_zcopy` on GB200 for better performance. The default value is `auto`.
```
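A hedged helper sketch of the check this new FAQ describes: it prints the fabric Cluster UUID that `nvidia-smi -q` reports on the local node (compare the value across nodes to see whether they share an NVLink domain) and then applies the suggested UCX settings. The regex and the `put_zcopy` choice are assumptions for illustration; benchmark both rendezvous schemes on your system.

```python
import os
import re
import subprocess

# Print the NVLink fabric Cluster UUID(s) visible on this node. Comparing the
# value across nodes shows whether they share an NVLink domain.
out = subprocess.run(["nvidia-smi", "-q"], capture_output=True, text=True).stdout
uuids = set(re.findall(r"Cluster\s*UUID\s*:\s*(\S+)", out))
print("Fabric Cluster UUID(s):", uuids or "not reported")

# If your nodes report different UUIDs, the doc suggests the settings below.
# They must end up in the environment of the trtllm-serve processes, e.g. by
# launching them from this script or exporting the variables in your shell.
os.environ["UCX_CUDA_IPC_ENABLE_MNNVL"] = "n"
os.environ["UCX_RNDV_SCHEME"] = "put_zcopy"   # or "get_zcopy"; measure both
os.environ.pop("UCX_NET_DEVICES", None)       # let UCX consider all devices
```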

tensorrt_llm/_torch/pyexecutor/kv_cache_transceiver.py

Lines changed: 26 additions & 19 deletions
```diff
@@ -31,29 +31,36 @@ def create_kv_cache_transceiver(
         mapping: Mapping, kv_cache_manager: KVCacheManager,
         attention_type: AttentionTypeCpp,
         cache_transceiver_config: CacheTransceiverConfig):
-    if cache_transceiver_config is None or (cache_transceiver_config.backend
-                                            is None):
+    if cache_transceiver_config is None or cache_transceiver_config.backend is None:
         logger.info("cache_transceiver is disabled")
         return None
-    if (cache_transceiver_config.backend == BackendTypeCpp.DEFAULT):
-
-        backend_type = BackendTypeCpp.UCX
-        if getenv("TRTLLM_USE_UCX_KVCACHE"):
-            backend_type = BackendTypeCpp.UCX
-        elif getenv("TRTLLM_USE_NIXL_KVCACHE"):
-            backend_type = BackendTypeCpp.NIXL
-        elif getenv("TRTLLM_USE_MPI_KVCACHE"):
-            backend_type = BackendTypeCpp.MPI
-        cache_transceiver_config.backend = backend_type
-
-    if (cache_transceiver_config.backend == BackendTypeCpp.MPI):
+
+    if cache_transceiver_config.backend == BackendTypeCpp.DEFAULT:
+        # When cache_transceiver_config.backend is not set, fall back to env var settings.
+        # UCX is the default backend.
+        cache_transceiver_config.backend = BackendTypeCpp.UCX
+        # Ordered by priority
+        env_vars = [("TRTLLM_USE_NIXL_KVCACHE", BackendTypeCpp.NIXL),
+                    ("TRTLLM_USE_MPI_KVCACHE", BackendTypeCpp.MPI)]
+        for env_var, be_type in env_vars:
+            if getenv(env_var) == "1":
+                logger.warning(
+                    f"{env_var}=1 is set, but it's recommended to set cache_transceiver_config.backend in yaml config"
+                )
+                cache_transceiver_config.backend = be_type
+                break
+
+    if cache_transceiver_config.backend == BackendTypeCpp.MPI:
         logger.warning(
             "MPI CacheTransceiver is deprecated, UCX or NIXL is recommended")
-    cache_transceiver = BindKvCacheTransceiver(mapping, kv_cache_manager,
-                                               attention_type,
-                                               cache_transceiver_config)
-
-    return cache_transceiver
+    elif cache_transceiver_config.backend == BackendTypeCpp.UCX:
+        logger.info(
+            f"Using UCX kv-cache transceiver. If your devices are not in the same domain, please consider setting "
+            f"UCX_CUDA_IPC_ENABLE_MNNVL=n, UCX_RNDV_SCHEME=put_zcopy and/or unset UCX_NET_DEVICES upon server "
+            f"hangs or lower-than-expected performance.")
+
+    return BindKvCacheTransceiver(mapping, kv_cache_manager, attention_type,
+                                  cache_transceiver_config)
 
 
 class KvCacheTransceiver(ABC):
```
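For readers skimming the diff, the selection order the new code implements is: an explicit `cache_transceiver_config.backend` always wins; otherwise `TRTLLM_USE_NIXL_KVCACHE=1`, then `TRTLLM_USE_MPI_KVCACHE=1`, are consulted; otherwise UCX is used. Below is a self-contained mirror of that logic; the enum and function names are stand-ins for illustration, not the real tensorrt_llm bindings.

```python
import os
from enum import Enum

class Backend(Enum):
    DEFAULT = "DEFAULT"
    UCX = "UCX"
    NIXL = "NIXL"
    MPI = "MPI"

def resolve_backend(configured: Backend) -> Backend:
    """Mirror of the fallback order in create_kv_cache_transceiver above."""
    if configured != Backend.DEFAULT:
        return configured  # an explicit backend from the config always wins
    # Environment variables are only a fallback, checked in priority order.
    if os.getenv("TRTLLM_USE_NIXL_KVCACHE") == "1":
        return Backend.NIXL
    if os.getenv("TRTLLM_USE_MPI_KVCACHE") == "1":
        return Backend.MPI
    return Backend.UCX  # UCX is the default when nothing else is requested

if __name__ == "__main__":
    print(resolve_backend(Backend.DEFAULT))  # UCX unless an env var is set
```

As the new warning message notes, setting `cache_transceiver_config.backend` explicitly in the YAML config is preferred over relying on these environment variables.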
