Skip to content

Commit a9c71ed

Browse files
MinaHuai authored and davidmlw committed
resolve ray conflict
1 parent b45cad7 commit a9c71ed

File tree

2 files changed

+3
-1
lines changed

2 files changed

+3
-1
lines changed

tensorrt_llm/llmapi/llm.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -164,7 +164,7 @@ def __init__(self,
164164
self.mpi_session = self.args.mpi_session
165165

166166
if self.args.parallel_config.is_multi_gpu:
167-
if get_device_count(
167+
if os.getenv("RAY_LOCAL_RANK") is None and get_device_count(
168168
) < self.args.parallel_config.world_size_per_node:
169169
raise RuntimeError(
170170
f"Only {get_device_count()} GPUs are available, but {self.args.parallel_config.world_size} are required."

tensorrt_llm/llmapi/llm_args.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1194,6 +1194,8 @@ def validate_quant_config(cls, v, info):
11941194
@field_validator("gpus_per_node", mode='before')
11951195
@classmethod
11961196
def validate_gpus_per_node(cls, v, info):
1197+
if os.getenv("RAY_LOCAL_RANK") is not None:
1198+
return info.data.get("tensor_parallel_size")
11971199
if v is None:
11981200
logger.warning(
11991201
f"Using default gpus_per_node: {torch.cuda.device_count()}")

0 commit comments

Comments
 (0)