Skip to content

Commit 167ec42

Browse files
authored
[ascend] add env to set rt visible devices by ray and disable warmup (#3894)
1 parent 73f01a9 commit 167ec42

File tree

4 files changed

+19
-7
lines changed

4 files changed

+19
-7
lines changed

lmdeploy/pytorch/backends/dlinfer/ascend/graph_runner.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -121,3 +121,8 @@ def prepare_inputs_for_generation(
121121
inputs_embeds=inputs_embeds,
122122
context=context,
123123
)
124+
125+
def get_capture_batch_sizes(self) -> List[int]:
126+
"""Capture batch sizes."""
127+
# TODO: warmup is disabled for now by returning no capture batch sizes.
128+
return []

lmdeploy/pytorch/backends/dlinfer/ascend/op_backend.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
import torch
1010

11+
from lmdeploy.pytorch import envs as _envs
1112
from lmdeploy.pytorch.config import BackendConfig, CacheConfig, ModelConfig
1213
from lmdeploy.utils import get_logger
1314

@@ -339,5 +340,6 @@ def device_count():
339340
@staticmethod
340341
def support_ray():
341342
"""Support ray."""
342-
os.environ['RAY_EXPERIMENTAL_NOSET_ASCEND_RT_VISIBLE_DEVICES'] = '1'
343+
if not _envs.ascend_set_rt_visable_devices_by_ray:
344+
os.environ['RAY_EXPERIMENTAL_NOSET_ASCEND_RT_VISIBLE_DEVICES'] = '1'
343345
return True

lmdeploy/pytorch/engine/executor/ray_executor.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -560,17 +560,21 @@ def _init_distributed_environment_by_device(self, device_str: str):
560560
def _init_ascend_distributed_environment(self, driver_ip):
561561
"""Init ascend distributed environment."""
562562
rank_table_file = _envs.ascend_rank_table_file
563-
if not rank_table_file:
564-
# if rank table file is not set, treat as single node
565-
self.workers = self._sort_workers(driver_ip, self.workers)
566-
# simply set device by index, this is for single node, multiple devices
567-
ray.get([worker.set_device.remote(idx) for idx, worker in enumerate(self.workers)])
568-
else:
563+
set_rt_visable_devices_by_ray = _envs.ascend_set_rt_visable_devices_by_ray
564+
565+
if rank_table_file:
569566
# if rank table file is set, use it to get rank mapping, multiple nodes
570567
rank_mapping, worker_ips, envs = get_ascend_device_rank_mapping(driver_ip)
571568
self.workers = self._sort_workers_by_ip(worker_ips, self.workers)
572569
ray.get([worker.set_device.remote(rank_mapping[idx]) for idx, worker in enumerate(self.workers)])
573570
ray.get([worker.set_env.remote(envs) for worker in self.workers])
571+
elif not set_rt_visable_devices_by_ray:
572+
# if rank table file is not set and ray does not set the visible devices, treat as single node
573+
# simply set device by index, this is for single node, multiple devices
574+
self.workers = self._sort_workers(driver_ip, self.workers)
575+
ray.get([worker.set_device.remote(idx) for idx, worker in enumerate(self.workers)])
576+
else:
577+
self.workers = self._sort_workers(driver_ip, self.workers)
574578

575579
""" PD Disaggregation API Begin """
576580

lmdeploy/pytorch/envs.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,7 @@ def _patched_get_env(
8787
ray_nsys_output_prefix = os.getenv('LMDEPLOY_RAY_NSYS_OUT_PREFIX', None)
8888

8989
# ascend
90+
ascend_set_rt_visable_devices_by_ray = env_to_bool('ASCEND_SET_RT_VISIBLE_DEVICES_BY_RAY', False)
9091
ascend_rank_table_file = os.getenv('ASCEND_RANK_TABLE_FILE_PATH')
9192

9293
# dp

0 commit comments

Comments
 (0)