
Commit 4b15d0d

fyrestone and 刘宝 authored
[Ray] Support worker_mem for ray executor (#3300)
* Remove subtask_id tag from ray dag metrics
* Refine ray dag gc
* Use set + list instead of OrderedSet
* Support ray task memory
* Add tests

Co-authored-by: 刘宝 <[email protected]>
1 parent 3949506 commit 4b15d0d

File tree

4 files changed: +40 -12 lines changed

mars/deploy/oscar/ray.py

Lines changed: 2 additions & 1 deletion
@@ -410,7 +410,7 @@ def __init__(
         supervisor_mem: int = 1 * 1024**3,
         worker_num: int = 1,
         worker_cpu: Union[int, float] = 2,
-        worker_mem: int = 4 * 1024**3,
+        worker_mem: int = 2 * 1024**3,
         backend: str = None,
         config: Union[str, Dict] = None,
         n_supervisor_process: int = DEFAULT_SUPERVISOR_SUB_POOL_NUM,
@@ -467,6 +467,7 @@ async def start(self):
                 n_cpu=self._worker_num * self._worker_cpu,
                 mem_bytes=self._worker_mem,
                 subtask_num_cpus=self._worker_cpu,
+                subtask_memory=self._worker_mem,
             )
         )
         assert self._n_supervisor_process == 0, self._n_supervisor_process
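
With the default lowered to 2 GiB and the value now forwarded as subtask_memory, callers size both the Ray worker and the per-subtask memory reservation with a single argument. A minimal sketch, assuming an awaitable new_cluster helper in mars.deploy.oscar.ray that exposes the same keyword arguments as the constructor above (the helper name, its signature, and the 6 GiB figure are assumptions, not part of this diff):

import asyncio

from mars.deploy.oscar.ray import new_cluster  # assumed entry point

async def main():
    # worker_mem is both the per-worker memory budget (mem_bytes) and, after
    # this commit, the per-subtask reservation passed to Ray (subtask_memory).
    client = await new_cluster(
        worker_num=2,
        worker_cpu=2,
        worker_mem=6 * 1024**3,  # 6 GiB, illustrative only
    )
    return client

asyncio.run(main())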

mars/services/task/execution/ray/config.py

Lines changed: 3 additions & 0 deletions
@@ -54,6 +54,9 @@ def get_subtask_max_retries(self):
     def get_subtask_num_cpus(self) -> Union[int, float]:
         return self._ray_execution_config.get("subtask_num_cpus", 1)
 
+    def get_subtask_memory(self) -> Union[int, float]:
+        return self._ray_execution_config.get("subtask_memory", None)
+
     def get_n_cpu(self):
         return self._ray_execution_config["n_cpu"]
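
The new accessor follows the same .get(key, default) pattern as get_subtask_num_cpus, but defaults to None so that no memory reservation is requested when the option is absent. A standalone sketch of the lookup (the free function and the literal values are illustrative; 1001 mirrors the test added below):

# Illustrative execution config; "subtask_memory" is the new optional key.
ray_execution_config = {
    "subtask_num_cpus": 0.8,
    "subtask_memory": 1001,  # bytes to reserve per subtask; omit to disable
    "n_cpu": 1,
}

def get_subtask_memory(config):
    # None is forwarded to Ray unchanged, which then makes no reservation.
    return config.get("subtask_memory", None)

assert get_subtask_memory(ray_execution_config) == 1001
assert get_subtask_memory({"n_cpu": 1}) is None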

mars/services/task/execution/ray/executor.py

Lines changed: 33 additions & 11 deletions
@@ -33,7 +33,6 @@
 )
 from .....core.operand.fetch import FetchShuffle
 from .....lib.aio import alru_cache
-from .....lib.ordered_set import OrderedSet
 from .....metrics.api import init_metrics, Metrics
 from .....resource import Resource
 from .....serialization import serialize, deserialize
@@ -76,12 +75,10 @@
 started_subtask_number = Metrics.counter(
     "mars.ray_dag.started_subtask_number",
     "The number of started subtask.",
-    ("subtask_id",),
 )
 completed_subtask_number = Metrics.counter(
     "mars.ray_dag.completed_subtask_number",
     "The number of completed subtask.",
-    ("subtask_id",),
 )
 
 
@@ -183,8 +180,7 @@ def execute_subtask(
     subtask outputs and meta for outputs if `output_meta_keys` is provided.
     """
     init_metrics("ray")
-    metrics_tags = {"subtask_id": subtask_id}
-    started_subtask_number.record(1, metrics_tags)
+    started_subtask_number.record(1)
     ray_task_id = ray.get_runtime_context().task_id
     subtask_chunk_graph = deserialize(*subtask_chunk_graph)
     logger.info("Start subtask: %s, ray task id: %s.", subtask_id, ray_task_id)
@@ -277,7 +273,7 @@ def execute_subtask(
     output_values.extend(normal_output.values())
     output_values.extend(mapper_output.values())
     logger.info("Complete subtask: %s, ray task id: %s.", subtask_id, ray_task_id)
-    completed_subtask_number.record(1, metrics_tags)
+    completed_subtask_number.record(1)
     return output_values[0] if len(output_values) == 1 else output_values
 
 
@@ -331,6 +327,32 @@ class _RayExecutionStage(enum.Enum):
     WAITING = 2
 
 
+class OrderedSet:
+    def __init__(self):
+        self._d = set()
+        self._l = list()
+
+    def add(self, item):
+        self._d.add(item)
+        self._l.append(item)
+        assert len(self._d) == len(self._l)
+
+    def update(self, items):
+        tmp = list(items) if isinstance(items, collections.Iterator) else items
+        self._l.extend(tmp)
+        self._d.update(tmp)
+        assert len(self._d) == len(self._l)
+
+    def __contains__(self, item):
+        return item in self._d
+
+    def __getitem__(self, item):
+        return self._l[item]
+
+    def __len__(self):
+        return len(self._d)
+
+
 @dataclass
 class _RayMonitorContext:
     stage: _RayExecutionStage = _RayExecutionStage.INIT
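
The local replacement keeps a set for O(1) membership tests and a list for insertion-ordered indexing, with asserts guarding against duplicate inserts. A standalone usage sketch, adapted to collections.abc.Iterator (the bare collections.Iterator alias used in the diff only exists on older Python versions); the subtask ids are made up:

import collections.abc

class OrderedSet:
    # Minimal set + list container: O(1) membership, list-style indexing.
    def __init__(self):
        self._d = set()
        self._l = list()

    def add(self, item):
        self._d.add(item)
        self._l.append(item)
        assert len(self._d) == len(self._l)  # items must be unique

    def update(self, items):
        # Materialize one-shot iterators so the set and the list see the same items.
        tmp = list(items) if isinstance(items, collections.abc.Iterator) else items
        self._l.extend(tmp)
        self._d.update(tmp)
        assert len(self._d) == len(self._l)

    def __contains__(self, item):
        return item in self._d

    def __getitem__(self, item):
        return self._l[item]

    def __len__(self):
        return len(self._d)

s = OrderedSet()
s.add("subtask-1")
s.update(iter(["subtask-2", "subtask-3"]))
assert "subtask-2" in s and s[0] == "subtask-1" and len(s) == 3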
@@ -576,6 +598,7 @@ async def _execute_subtask_graph(
         )
         subtask_max_retries = self._config.get_subtask_max_retries()
         subtask_num_cpus = self._config.get_subtask_num_cpus()
+        subtask_memory = self._config.get_subtask_memory()
         metrics_tags = {
             "session_id": self._task.session_id,
             "task_id": self._task.task_id,
@@ -608,6 +631,7 @@ async def _execute_subtask_graph(
                 num_cpus=subtask_num_cpus,
                 num_returns=output_count,
                 max_retries=subtask_max_retries,
+                memory=subtask_memory,
                 scheduling_strategy="DEFAULT" if len(input_object_refs) else "SPREAD",
             ).remote(
                 subtask.subtask_id,
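
Ray's memory option reserves the given number of bytes for scheduling the task, alongside num_cpus and max_retries; passing None (the default when subtask_memory is unset) requests no reservation. A minimal sketch of the pattern outside Mars (the task body and figures are illustrative):

import ray

ray.init(ignore_reinit_error=True)

@ray.remote
def run_subtask(payload):
    # Stand-in for the Mars execute_subtask function; just echoes its input.
    return payload

# The scheduler only places the task on a node with 512 MiB of memory available.
ref = run_subtask.options(
    num_cpus=0.8,
    max_retries=4,
    memory=512 * 1024**2,  # bytes, illustrative
).remote({"subtask_id": "demo"})

print(ray.get(ref))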
@@ -840,11 +864,9 @@ def gc():
             for pred in subtask_graph.iter_predecessors(subtask):
                 if pred in gc_subtasks:
                     continue
-                while not all(
-                    succ in gc_targets
-                    for succ in subtask_graph.iter_successors(pred)
-                ):
-                    yield
+                for succ in subtask_graph.iter_successors(pred):
+                    while succ not in gc_targets:
+                        yield
                 if pred.virtual:
                     # For virtual subtask, remove all the predecessors if it is
                     # completed.
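
The refined loop waits on each successor once instead of re-evaluating all(...) over every successor each time the generator resumes, so a predecessor with many successors is no longer rescanned from scratch on every yield. A small sketch of the two generator shapes, with plain collections standing in for the subtask graph:

# Successors of one predecessor, and the subtasks already marked for gc.
successors = ["s1", "s2", "s3"]
gc_targets = set()

def wait_all_rescan():
    # Old shape: every resume re-checks all successors.
    while not all(succ in gc_targets for succ in successors):
        yield

def wait_each_once():
    # New shape: pin one successor at a time and only re-check that one.
    for succ in successors:
        while succ not in gc_targets:
            yield

g = wait_each_once()
next(g)                  # still waiting on "s1"
gc_targets.update(successors)
assert list(g) == []     # finishes once every successor is collected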

mars/services/task/execution/ray/tests/test_ray_execution_backend.py

Lines changed: 2 additions & 0 deletions
@@ -371,6 +371,7 @@ def options(cls, *args, **kwargs):
             "monitor_interval_seconds": 0,
             "subtask_max_retries": 4,
             "subtask_num_cpus": 0.8,
+            "subtask_memory": 1001,
             "n_cpu": 1,
             "n_worker": 1,
         },
@@ -394,6 +395,7 @@ def options(cls, *args, **kwargs):
 
     assert MockExecutor.opt["num_cpus"] == 0.8
     assert MockExecutor.opt["max_retries"] == 4
+    assert MockExecutor.opt["memory"] == 1001
 
 
 @require_ray
